{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Loading and saving data between StellarGraph and Neo4j\n",
    "\n",
    "> This demo explains how to load data from Neo4j into a form that can be used by the StellarGraph library, and how to save predictions back into the database. [See all other demos](../README.md).\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "source": [
    "<table><tr><td>Run the latest release of this notebook:</td><td><a href=\"https://mybinder.org/v2/gh/stellargraph/stellargraph/master?urlpath=lab/tree/demos/basics/loading-saving-neo4j.ipynb\" alt=\"Open In Binder\" target=\"_parent\"><img src=\"https://mybinder.org/badge_logo.svg\"/></a></td><td><a href=\"https://colab.research.google.com/github/stellargraph/stellargraph/blob/master/demos/basics/loading-saving-neo4j.ipynb\" alt=\"Open In Colab\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\"/></a></td></tr></table>"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "[The StellarGraph library](https://github.com/stellargraph/stellargraph) supports loading graph information from Neo4j. [Neo4j](https://neo4j.com) is a popular graph database.\n",
    "\n",
    "If your data is already in Neo4j, this is a great way to load it. If not, [loading via another route](README.md) is likely to be faster and potentially more convenient.\n",
    "\n",
    "This notebook demonstrates one approach to connecting StellarGraph and Neo4j. It uses the SQL-like [Cypher language](https://neo4j.com/developer/cypher-query-language/) to read a graph or subgraph from Neo4j into [Pandas](https://pandas.pydata.org) DataFrames, and then uses these to construct a `StellarGraph` object (following the same techniques as in the [loading via Pandas](loading-pandas.ipynb) demo, which has more details about that aspect). It assumes some familiarity with Cypher constructs like `MATCH`, `RETURN` and `WHERE`, and uses the [Py2neo](http://py2neo.org/) library to interact with a Neo4j instance.\n",
    "\n",
    "> StellarGraph also has experimental support for [running some algorithms directly using Neo4j](../connector/neo4j/README.md).\n",
    "\n",
    "This notebook walks through the following scenarios for loading and storing graphs:\n",
    "\n",
    "- homogeneous graph without features (a homogeneous graph is one with only one type of node and one type of edge)\n",
    "- homogeneous graph with features\n",
    "- homogeneous graph with edge weights\n",
    "- directed graphs (a graph is directed if each edge has a \"start\" node and an \"end\" node, instead of just connecting two nodes)\n",
    "- heterogeneous graphs (more than one node type and/or more than one edge type), with and without node features or edge weights; this includes knowledge graphs\n",
    "- subgraphs (an example of filtering which nodes and edges are loaded)\n",
    "- saving predictions into Neo4j\n",
    "\n",
    "> StellarGraph supports loading data from many sources with all sorts of data preprocessing, via [Pandas](https://pandas.pydata.org) DataFrames, [NumPy](https://www.numpy.org) arrays, [Neo4j](https://neo4j.com) and [NetworkX](https://networkx.github.io) graphs. See [all loading demos](README.md) for more details.\n",
    "\n",
    "The `StellarGraph` class is available at the top level of the `stellargraph` library:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "CloudRunner"
    ]
   },
   "outputs": [],
   "source": [
    "# install StellarGraph if running on Google Colab\n",
    "import sys\n",
    "if 'google.colab' in sys.modules:\n",
    "  %pip install -q stellargraph[demos]==1.2.1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "nbsphinx": "hidden",
    "tags": [
     "VersionCheck"
    ]
   },
   "outputs": [],
   "source": [
    "# verify that we're using the correct version of StellarGraph for this notebook\n",
    "import stellargraph as sg\n",
    "\n",
    "try:\n",
    "    sg.utils.validate_notebook_version(\"1.2.1\")\n",
    "except AttributeError:\n",
    "    raise ValueError(\n",
    "        f\"This notebook requires StellarGraph version 1.2.1, but a different version {sg.__version__} is installed.  Please see <https://github.com/stellargraph/stellargraph/issues/1172>.\"\n",
    "    ) from None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from stellargraph import StellarGraph"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Connecting to Neo4j\n",
    "\n",
    "To read anything from Neo4j, we'll need a connection to a running instance."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import py2neo\n",
    "\n",
    "default_host = os.environ.get(\"STELLARGRAPH_NEO4J_HOST\")\n",
    "\n",
    "# Create the Neo4j Graph database object; the parameters can be edited to specify location and authentication\n",
    "neo4j_graph = py2neo.Graph(host=default_host, port=None, user=None, password=None)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Dataset\n",
    "\n",
    "We'll be working with a graph representing a square with a diagonal. We'll give node `a` the label `foo` and the other nodes the label `bar`, along with some features. We'll also give each edge a type matching its orientation, and a weight.\n",
    "\n",
    "```\n",
    "a -- b\n",
    "| \\  |\n",
    "|  \\ |\n",
    "d -- c\n",
    "```\n",
    "\n",
    "This section uses the types from `py2neo` to seed our Neo4j instance with the example data. For real work involving StellarGraph and Neo4j, the data would already have been loaded into the database by some external process. However, this demo needs some data to work with, which is why the cells in this section exist. They can be **safely ignored**, and removed for real work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "from py2neo.data import Node, Relationship, Subgraph\n",
    "\n",
    "a = Node(\"foo\", name=\"a\", top=True, left=True, foo_numbers=[0.1, 0.2, 0.3])\n",
    "b = Node(\"bar\", name=\"b\", top=True, left=False, bar_numbers=[1, -2])\n",
    "c = Node(\"bar\", name=\"c\", top=False, left=False, bar_numbers=[34, 5.6])\n",
    "d = Node(\"bar\", name=\"d\", top=False, left=True, bar_numbers=[0.7, -98])\n",
    "\n",
    "ab = Relationship(a, \"horizontal\", b, weight=1.0)\n",
    "bc = Relationship(b, \"vertical\", c, weight=0.2)\n",
    "cd = Relationship(c, \"horizontal\", d, weight=3.4)\n",
    "da = Relationship(d, \"vertical\", a, weight=5.67)\n",
    "ac = Relationship(a, \"diagonal\", c, weight=1.0)\n",
    "\n",
    "subgraph = Subgraph([a, b, c, d], [ab, bc, cd, da, ac])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We don't want to accidentally overwrite or delete important data or add junk in a production Neo4j instance. As a check, this demo requires the Neo4j instance to be empty. If the `neo4j_graph` connection is to a non-empty database, please either:\n",
    "\n",
    "- delete everything from it (there's a cell at the end of the notebook that can be used, if that's ok)\n",
    "- start a new instance, adjust the parameters to `py2neo.Graph` above to connect to it, and rerun the cells from there"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_nodes = len(neo4j_graph.nodes)\n",
    "num_relationships = len(neo4j_graph.relationships)\n",
    "if num_nodes > 0 or num_relationships > 0:\n",
    "    raise ValueError(\n",
    "        f\"neo4j_graph: expected an empty database to give a reliable result and to avoid corrupting your data with mutations & the `delete_all` in the last cell, found {num_nodes} nodes and {num_relationships} relationships in the database already\"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we can write our example data to the database."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "neo4j_graph.create(subgraph)\n",
    "\n",
    "# basic check that the database has the right data\n",
    "assert len(neo4j_graph.nodes) == 4\n",
    "assert len(neo4j_graph.relationships) == 5"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Homogeneous graph without features (edges only)\n",
    "\n",
    "We'll start with a homogeneous graph without any node features. This means the graph consists of only nodes and edges without any information other than a unique identifier. To simulate this, we will ignore all of the properties we added except the `name` property, which is a unique identifier for each node.\n",
    "\n",
    "We can use a single Cypher query to retrieve the identifiers for the source and target of each edge. We're using `name` as the identifier here, and each application should choose an appropriate identifier, such as the internal node ID returned by `id(...)` ([docs](https://neo4j.com/docs/cypher-manual/current/functions/scalar/#functions-id)) if [the dangers of ID reuse](https://neo4j.com/docs/cypher-manual/current/clauses/match/#match-node-by-id) don't apply.\n",
    "\n",
    "We can execute a Cypher query using the `run` method ([docs](https://py2neo.org/v4/database.html#py2neo.database.Graph.run)) of `py2neo.Graph`, which returns a `Cursor` object that has a `to_data_frame` method ([docs](https://py2neo.org/v4/database.html#py2neo.database.Cursor.to_data_frame)) to convert the results to a columnar DataFrame. The `StellarGraph` type expects the columns for the nodes in an edge to be called `source` and `target` by default, so the query uses `AS` to ensure the DataFrame columns match those defaults."
   ]
  },
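  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If a query leaves out the `AS` aliases, the DataFrame columns are named after the returned expressions instead (the node feature query later in this notebook shows this, with columns like `n.top`). In that case, the columns can also be renamed in Pandas after loading. A minimal sketch on a hypothetical, hard-coded query result (`unaliased_edges` is an illustrative name), so that it runs without Neo4j:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical result of `RETURN s.name, t.name` without `AS` aliases (first two edges only)\n",
    "unaliased_edges = pd.DataFrame({\"s.name\": [\"d\", \"a\"], \"t.name\": [\"a\", \"b\"]})\n",
    "\n",
    "# Rename to the `source` and `target` column names that StellarGraph uses by default\n",
    "renamed_edges = unaliased_edges.rename(columns={\"s.name\": \"source\", \"t.name\": \"target\"})\n",
    "renamed_edges"
   ]
  },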
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>d</td>\n",
       "      <td>a</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>a</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>b</td>\n",
       "      <td>c</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>a</td>\n",
       "      <td>c</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>c</td>\n",
       "      <td>d</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  source target\n",
       "0      d      a\n",
       "1      a      b\n",
       "2      b      c\n",
       "3      a      c\n",
       "4      c      d"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "edges = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (s) --> (t) \n",
    "    RETURN s.name AS source, t.name AS target\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "edges.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have a DataFrame where each row represents an edge in the graph, which is exactly the format expected by the `StellarGraph` constructor ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.StellarGraph)). We can pass the DataFrame as the `edges` parameter:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "edges_only = StellarGraph(edges=edges)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `info` method ([docs](https://stellargraph.readthedocs.io/en/stable/api.html#stellargraph.StellarGraph.info)) gives a high-level summary of a `StellarGraph`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: none\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "print(edges_only.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "On this square, it tells us that there are 4 nodes of type `default` (a homogeneous graph still has node and edge types, but they default to `default`), with no features, and one type of edge between them. It also tells us that there are 5 edges of type `default` that go between nodes of type `default`. This matches what we expect: it's a graph with 4 nodes and 5 edges and one type of each."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Homogeneous graph with features\n",
    "\n",
    "For many real-world problems, we have more than just graph structure: we have information about the nodes and edges. For instance, we might have a graph of academic papers (nodes) and how they cite each other (edges): we might have information about the nodes such as the authors and the publication year, and even the abstract or full paper contents. If we're doing a machine learning task, it can be useful to feed this information into models. The `StellarGraph` class supports this using another Pandas DataFrame: each row corresponds to a feature vector for a node.\n",
    "\n",
    "We can create an appropriate DataFrame in the same way as we created the edges one, with a Cypher query that selects the relevant information. In this case, we need the `name` to match the rows of features to their node, and we're also going to have 3 features:\n",
    "\n",
    "- the `top` and `left` properties from each node as two of the features\n",
    "- whether the `bar_numbers` property exists on the node using the `exists` function ([docs](https://neo4j.com/docs/cypher-manual/current/functions/predicate/#functions-exists)): this is a demonstration that features don't have to be just properties, but can be calculated with any computation supported by Neo4j"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>n.top</th>\n",
       "      <th>n.left</th>\n",
       "      <th>exists(n.bar_numbers)</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>a</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>b</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>c</td>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>d</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  name  n.top  n.left  exists(n.bar_numbers)\n",
       "0    a   True    True                  False\n",
       "1    b   True   False                   True\n",
       "2    c  False   False                   True\n",
       "3    d  False    True                   True"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_homogeneous_nodes = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (n) \n",
    "    RETURN n.name AS name, n.top, n.left, exists(n.bar_numbers)\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "\n",
    "raw_homogeneous_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`StellarGraph` uses the index of the DataFrame as the connection between a node and a row of the DataFrame. Currently our DataFrame just has a simple numeric range as the index, but it needs to use the `name` column. Pandas offers [a few ways to control the indexing](https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#set-reset-index); in this case, we want to replace the current index by moving the `name` column to it, which is done most easily with `set_index`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>n.top</th>\n",
       "      <th>n.left</th>\n",
       "      <th>exists(n.bar_numbers)</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>a</th>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>c</th>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>d</th>\n",
       "      <td>False</td>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      n.top  n.left  exists(n.bar_numbers)\n",
       "name                                      \n",
       "a      True    True                  False\n",
       "b      True   False                   True\n",
       "c     False   False                   True\n",
       "d     False    True                   True"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "homogeneous_nodes = raw_homogeneous_nodes.set_index(\"name\")\n",
    "homogeneous_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've now got all the right node data, in addition to the edges from before, so now we can create a `StellarGraph`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: float32 vector, length 3\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "homogeneous = StellarGraph(homogeneous_nodes, edges)\n",
    "print(homogeneous.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice the output of `info` now says that the nodes of the `default` type have 3 features."
   ]
  },
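  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three boolean columns become a numeric feature vector: `info` reports a `float32` vector of length 3 for each node. As a rough, Pandas-only sketch of that conversion (assuming a plain cast is an adequate model of what `StellarGraph` does internally; this is not StellarGraph's own code), the cell below casts a hard-coded copy of the node table to `float32`, so it runs without Neo4j. The names `example_nodes` and `example_features` are illustrative."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical hard-coded copy of the node table above, so this cell runs without Neo4j\n",
    "example_nodes = pd.DataFrame(\n",
    "    {\n",
    "        \"n.top\": [True, True, False, False],\n",
    "        \"n.left\": [True, False, False, True],\n",
    "        \"exists(n.bar_numbers)\": [False, True, True, True],\n",
    "    },\n",
    "    index=pd.Index([\"a\", \"b\", \"c\", \"d\"], name=\"name\"),\n",
    ")\n",
    "\n",
    "# True/False become 1.0/0.0, giving one float32 vector of length 3 per node\n",
    "example_features = example_nodes.astype(\"float32\").to_numpy()\n",
    "example_features"
   ]
  },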
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Homogeneous graph with edge weights\n",
    "\n",
    "Some algorithms can understand edge weights, which can be used as a measure of the strength of the connection, or a measure of distance between nodes. A `StellarGraph` instance can have weighted edges by including a `weight` column in the DataFrame of edges.\n",
    "\n",
    "We can extend our Cypher query that loads the edge sources and targets to also load the `weight` property. As with node features, we could use any computation supported by Neo4j to calculate the weight, beyond just accessing a property as we do here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "      <th>weight</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>d</td>\n",
       "      <td>a</td>\n",
       "      <td>5.67</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>a</td>\n",
       "      <td>b</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>b</td>\n",
       "      <td>c</td>\n",
       "      <td>0.20</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>a</td>\n",
       "      <td>c</td>\n",
       "      <td>1.00</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>c</td>\n",
       "      <td>d</td>\n",
       "      <td>3.40</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  source target  weight\n",
       "0      d      a    5.67\n",
       "1      a      b    1.00\n",
       "2      b      c    0.20\n",
       "3      a      c    1.00\n",
       "4      c      d    3.40"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "weighted_edges = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (s) -[r]-> (t) \n",
    "    RETURN s.name AS source, t.name AS target, r.weight AS weight\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "weighted_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: float32 vector, length 3\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: range=[0.2, 5.67], mean=2.254, std=2.25534\n"
     ]
    }
   ],
   "source": [
    "weighted_homogeneous = StellarGraph(homogeneous_nodes, weighted_edges)\n",
    "print(weighted_homogeneous.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notice the output of `info` now shows additional statistics about edge weights."
   ]
  },
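  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a worked check of those statistics, the cell below recomputes them with Pandas from a hard-coded copy of the five weights. The mean is (5.67 + 1.0 + 0.2 + 1.0 + 3.4) / 5 = 2.254. The reported `std=2.25534` agrees with the sample standard deviation (`ddof=1`, the Pandas default) rather than the population one (about 2.017); this is an observation from these particular numbers, not a documented guarantee about `info`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical hard-coded copy of the `weight` column above (d-a, a-b, b-c, a-c, c-d)\n",
    "example_weights = pd.Series([5.67, 1.0, 0.2, 1.0, 3.4])\n",
    "\n",
    "weight_mean = example_weights.mean()\n",
    "# Pandas' `std` is the sample standard deviation (ddof=1) by default\n",
    "weight_sample_std = example_weights.std()\n",
    "weight_population_std = example_weights.std(ddof=0)\n",
    "\n",
    "print(f\"range=[{example_weights.min()}, {example_weights.max()}], mean={weight_mean:.3f}\")\n",
    "print(f\"sample std={weight_sample_std:.5f}, population std={weight_population_std:.5f}\")"
   ]
  },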
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Directed graphs\n",
    "\n",
    "Some graphs have edge directions, where going from source to target has a different meaning to going from target to source.\n",
    "\n",
    "A directed graph can be created by using the `StellarDiGraph` class instead of the `StellarGraph` one. The construction is almost identical, and we can reuse any of the DataFrames that we created in the sections above. For instance, reusing the DataFrames from the previous section, we can create a directed homogeneous graph with node features and edge weights."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarDiGraph: Directed multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: float32 vector, length 3\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [5]\n",
      "        Weights: range=[0.2, 5.67], mean=2.254, std=2.25534\n"
     ]
    }
   ],
   "source": [
    "from stellargraph import StellarDiGraph\n",
    "\n",
    "directed_weighted_homogeneous = StellarDiGraph(homogeneous_nodes, weighted_edges)\n",
    "print(directed_weighted_homogeneous.info())"
   ]
  },
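  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The input data is identical to the undirected graph; only the interpretation of each row changes. The cell below is a Pandas-only sketch of that difference on a hard-coded copy of the edge list (`example_edges` and the other names are illustrative): in the directed view, a node's neighbours are only the targets of its outgoing edges, while in the undirected view every edge can be followed in both directions. It illustrates the semantics only, and is not how StellarGraph computes neighbours."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "# Hypothetical hard-coded copy of the edge list above; each row is source -> target\n",
    "example_edges = pd.DataFrame(\n",
    "    {\"source\": [\"d\", \"a\", \"b\", \"a\", \"c\"], \"target\": [\"a\", \"b\", \"c\", \"c\", \"d\"]}\n",
    ")\n",
    "\n",
    "# Directed: a node's neighbours are the targets of its outgoing edges only\n",
    "out_neighbours = example_edges.groupby(\"source\")[\"target\"].apply(sorted).to_dict()\n",
    "\n",
    "# Undirected: add a reversed copy of every edge, so each edge can be followed both ways\n",
    "reversed_edges = example_edges.rename(columns={\"source\": \"target\", \"target\": \"source\"})\n",
    "both_ways = pd.concat([example_edges, reversed_edges], ignore_index=True)\n",
    "undirected_neighbours = both_ways.groupby(\"source\")[\"target\"].apply(sorted).to_dict()\n",
    "\n",
    "print(\"directed:  \", out_neighbours)\n",
    "print(\"undirected:\", undirected_neighbours)"
   ]
  },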
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Heterogeneous graphs\n",
    "\n",
    "Some graphs have multiple types of nodes and multiple types of edges. Each type might have different data associated with it.\n",
    "\n",
    "For example, an academic citation network that includes authors might have `wrote` edges connecting `author` nodes to `paper` nodes, in addition to the `cites` edges between `paper` nodes. There could be `supervised` edges between `author`s ([example](https://academictree.org)) too, or any number of additional node and edge types. A knowledge graph (aka RDF, triple stores or knowledge base) is an extreme form of a heterogeneous graph, with dozens, hundreds or even thousands of edge (or relation) types. Typically in a knowledge graph, edges and their types represent the information associated with a node, rather than node features.\n",
    "\n",
    "`StellarGraph` supports all forms of heterogeneous graphs.\n",
    "\n",
    "A heterogeneous `StellarGraph` can be constructed in a similar way to a homogeneous graph, except we pass a dictionary with multiple elements instead of the single DataFrame we used in the \"homogeneous graph with features\" section and others above. For a heterogeneous graph, a dictionary has to be passed; passing a single DataFrame does not work.\n",
    "\n",
    "### Multiple node types\n",
    "\n",
    "The nodes of our square graph were given labels when we created them: `a` is of type `foo`, but `b`, `c` and `d` are of type `bar`. The `foo` node has an attribute `foo_numbers` that is a list/vector of numbers, and similarly the `bar` nodes have `bar_numbers`. These vectors might be some sort of summary of text associated with each node, or any other precomputed information about the node to use as input to our machine learning algorithm.\n",
    "\n",
    "The two types have properties with different names, and they have different lengths: the `foo` node has a list of length 3, while all of the `bar` nodes have a list of length 2. We will load them into separate DataFrames with separate Cypher queries, first finding the node(s) of type `foo` and their properties, and then the same for the nodes of type `bar`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>name</th>\n",
       "      <th>numbers</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>a</td>\n",
       "      <td>[0.1, 0.2, 0.3]</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  name          numbers\n",
       "0    a  [0.1, 0.2, 0.3]"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_foo_nodes = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (n:foo) \n",
    "    RETURN n.name AS name, n.foo_numbers AS numbers\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "raw_foo_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this case, our features are more complicated than just independent booleans that can become columns; instead we have a list that we need to turn into individual columns. One way is to convert the list column to a list of lists, and pass this to the `pd.DataFrame` constructor to build a new DataFrame with one column per list element. We can set the index directly with this technique, and do not need to separately use `set_index`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd"
   ]
  },
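  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One risk with this technique is that `pd.DataFrame` pads shorter lists with `NaN` if the lists have different lengths, which would silently produce missing feature values. The hypothetical helper `expand_list_column` in the cell below sketches the same list-of-lists approach with an explicit length check; it is not part of StellarGraph or Pandas, and it is applied to a hard-coded copy of the `bar` data so that it runs without Neo4j."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def expand_list_column(raw, list_column, index_column):\n",
    "    \"\"\"Hypothetical helper: expand a column of equal-length lists into one column per element.\"\"\"\n",
    "    lengths = raw[list_column].map(len)\n",
    "    if lengths.nunique() != 1:\n",
    "        # pd.DataFrame would silently pad shorter lists with NaN, so fail loudly instead\n",
    "        raise ValueError(\n",
    "            f\"{list_column}: expected lists of a single length, found lengths {sorted(set(lengths))}\"\n",
    "        )\n",
    "    return pd.DataFrame(raw[list_column].tolist(), index=raw[index_column])\n",
    "\n",
    "\n",
    "# Hypothetical hard-coded copy of the `bar` nodes' data, so this cell runs without Neo4j\n",
    "example_raw = pd.DataFrame(\n",
    "    {\"name\": [\"b\", \"c\", \"d\"], \"numbers\": [[1, -2], [34, 5.6], [0.7, -98]]}\n",
    ")\n",
    "example_expanded = expand_list_column(example_raw, \"numbers\", \"name\")\n",
    "example_expanded"
   ]
  },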
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "      <th>2</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>a</th>\n",
       "      <td>0.1</td>\n",
       "      <td>0.2</td>\n",
       "      <td>0.3</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "        0    1    2\n",
       "name               \n",
       "a     0.1  0.2  0.3"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "foo_nodes = pd.DataFrame(raw_foo_nodes[\"numbers\"].tolist(), index=raw_foo_nodes[\"name\"])\n",
    "foo_nodes"
   ]
  },
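  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One pitfall with this technique: if the lists stored in Neo4j have different lengths (for example, because some nodes were written with a different number of features), the `DataFrame` constructor pads the shorter rows with missing values instead of failing. The following is a minimal sketch of a hypothetical helper, `expand_list_column`, that checks the lengths first. It uses a small hand-written DataFrame as a stand-in for the result of a Cypher query, so it runs without a Neo4j connection.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hand-written stand-in for a DataFrame like `raw_bar_nodes`\n",
    "raw = pd.DataFrame(\n",
    "    {\"name\": [\"b\", \"c\", \"d\"], \"numbers\": [[1.0, -2.0], [34.0, 5.6], [0.7, -98.0]]}\n",
    ")\n",
    "\n",
    "\n",
    "def expand_list_column(df, list_column, index_column):\n",
    "    # Hypothetical helper: one numeric column per list position, indexed by `index_column`\n",
    "    lengths = df[list_column].map(len)\n",
    "    if lengths.nunique() != 1:\n",
    "        # pd.DataFrame would silently pad shorter lists, so fail loudly instead\n",
    "        raise ValueError(\"feature lists have different lengths\")\n",
    "    return pd.DataFrame(df[list_column].tolist(), index=df[index_column])\n",
    "\n",
    "\n",
    "expanded = expand_list_column(raw, \"numbers\", \"name\")\n",
    "print(expanded.shape)  # (3, 2)\n",
    "```"
   ]
  },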
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've now got a DataFrame with 3 columns of numbers, as required!\n",
    "\n",
    "We can do the same for the nodes of type `bar` to get a DataFrame with 2 columns of numbers:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>0</th>\n",
       "      <th>1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>1.0</td>\n",
       "      <td>-2.0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>c</th>\n",
       "      <td>34.0</td>\n",
       "      <td>5.6</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>d</th>\n",
       "      <td>0.7</td>\n",
       "      <td>-98.0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "         0     1\n",
       "name            \n",
       "b      1.0  -2.0\n",
       "c     34.0   5.6\n",
       "d      0.7 -98.0"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_bar_nodes = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (n:bar) \n",
    "    RETURN n.name AS name, n.bar_numbers AS numbers\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "\n",
    "bar_nodes = pd.DataFrame(raw_bar_nodes[\"numbers\"].tolist(), index=raw_bar_nodes[\"name\"])\n",
    "bar_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We have the information for the two node types `foo` and `bar` in separate DataFrames, so we can now put them in a dictionary to create a `StellarGraph`. Notice that `info()` is now reporting multiple node types, as well as information specific to each."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: bar-default->bar, bar-default->foo\n",
      "  foo: [1]\n",
      "    Features: float32 vector, length 3\n",
      "    Edge types: foo-default->bar\n",
      "\n",
      " Edge types:\n",
      "    foo-default->bar: [2]\n",
      "        Weights: all 1 (default)\n",
      "    bar-default->bar: [2]\n",
      "        Weights: all 1 (default)\n",
      "    bar-default->foo: [1]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "heterogeneous_nodes = StellarGraph({\"foo\": foo_nodes, \"bar\": bar_nodes}, edges)\n",
    "print(heterogeneous_nodes.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Multiple edge types\n",
    "\n",
    "Graphs with multiple edge types are simpler. Since we have no features on the edges, we can pass a single DataFrame with an additional column for the type, and specify that column via the `edge_type_column` parameter. (Multiple edge types can also be created in the same way as multiple node types, by passing a dictionary of DataFrames, but this is not necessary.)\n",
    "\n",
    "For example, our square graph has labelled each edge with its orientation. We can retrieve this using the `type` function ([docs](https://neo4j.com/docs/cypher-manual/current/functions/scalar/#functions-type)) to get a DataFrame with a label column too."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "      <th>label</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>d</td>\n",
       "      <td>a</td>\n",
       "      <td>vertical</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>a</td>\n",
       "      <td>b</td>\n",
       "      <td>horizontal</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>b</td>\n",
       "      <td>c</td>\n",
       "      <td>vertical</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>a</td>\n",
       "      <td>c</td>\n",
       "      <td>diagonal</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>c</td>\n",
       "      <td>d</td>\n",
       "      <td>horizontal</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  source target       label\n",
       "0      d      a    vertical\n",
       "1      a      b  horizontal\n",
       "2      b      c    vertical\n",
       "3      a      c    diagonal\n",
       "4      c      d  horizontal"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "labelled_edges = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (s) -[r]-> (t) \n",
    "    RETURN s.name AS source, t.name AS target, type(r) AS label\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "\n",
    "labelled_edges"
   ]
  },
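  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As mentioned above, multiple edge types can alternatively be given as a dictionary of DataFrames, one per type. Here is a minimal sketch of building such a dictionary from a labelled edge list with Pandas `groupby`. It uses a hand-written copy of the edges shown above (named `example_edges` so it does not clash with `labelled_edges`), and the final `StellarGraph` call is left as a comment because it assumes StellarGraph is installed.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hand-written copy of the labelled edges of the square graph\n",
    "example_edges = pd.DataFrame(\n",
    "    {\n",
    "        \"source\": [\"d\", \"a\", \"b\", \"a\", \"c\"],\n",
    "        \"target\": [\"a\", \"b\", \"c\", \"c\", \"d\"],\n",
    "        \"label\": [\"vertical\", \"horizontal\", \"vertical\", \"diagonal\", \"horizontal\"],\n",
    "    }\n",
    ")\n",
    "\n",
    "# One DataFrame per edge type, keyed by type; the label column is no longer needed\n",
    "edges_by_type = {\n",
    "    label: group.drop(columns=\"label\") for label, group in example_edges.groupby(\"label\")\n",
    "}\n",
    "\n",
    "# groupby sorts the keys, so this prints: diagonal 1, horizontal 2, vertical 2\n",
    "for label, group in edges_by_type.items():\n",
    "    print(label, len(group))\n",
    "\n",
    "# With StellarGraph installed: StellarGraph(edges=edges_by_type)\n",
    "```"
   ]
  },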
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We now have a DataFrame of the edges with a column for their types, so we can create a graph with one node type but multiple edge types. Notice how `info()` shows 3 edge types."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  default: [4]\n",
      "    Features: none\n",
      "    Edge types: default-diagonal->default, default-horizontal->default, default-vertical->default\n",
      "\n",
      " Edge types:\n",
      "    default-vertical->default: [2]\n",
      "        Weights: all 1 (default)\n",
      "    default-horizontal->default: [2]\n",
      "        Weights: all 1 (default)\n",
      "    default-diagonal->default: [1]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "heterogeneous_edges = StellarGraph(edges=labelled_edges, edge_type_column=\"label\")\n",
    "print(heterogeneous_edges.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The edges can be weighted if desired.\n",
    "\n",
    "`StellarGraph` supports multiple node types and multiple edge types at the same time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 4, Edges: 5\n",
      "\n",
      " Node types:\n",
      "  bar: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: bar-diagonal->foo, bar-horizontal->bar, bar-horizontal->foo, bar-vertical->bar, bar-vertical->foo\n",
      "  foo: [1]\n",
      "    Features: float32 vector, length 3\n",
      "    Edge types: foo-diagonal->bar, foo-horizontal->bar, foo-vertical->bar\n",
      "\n",
      " Edge types:\n",
      "    foo-horizontal->bar: [1]\n",
      "        Weights: all 1 (default)\n",
      "    foo-diagonal->bar: [1]\n",
      "        Weights: all 1 (default)\n",
      "    bar-vertical->foo: [1]\n",
      "        Weights: all 1 (default)\n",
      "    bar-vertical->bar: [1]\n",
      "        Weights: all 1 (default)\n",
      "    bar-horizontal->bar: [1]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "heterogeneous_everything = StellarGraph(\n",
    "    {\"foo\": foo_nodes, \"bar\": bar_nodes}, labelled_edges, edge_type_column=\"label\"\n",
    ")\n",
    "print(heterogeneous_everything.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Subgraphs\n",
    "\n",
    "In many cases, one wants to work with only a subgraph of the data that is stored in Neo4j. For example:\n",
    "\n",
    "- only some nodes and edges are relevant to the model, so filtering in the database avoids transferring unnecessary data\n",
    "- only a small amount of the data has labels for machine learning, so again one can reduce how much data is transferred\n",
    "- it's faster and easier to explore and experiment with a smaller version of a huge graph\n",
    "\n",
    "The Cypher queries we're using to load our data can be extended to handle these cases.\n",
    "\n",
    "### Node/edge filtering\n",
    "\n",
    "One type of subgraph in which someone might be interested is one where the nodes and/or edges satisfy certain criteria. This can be done by applying filters like a `WHERE` clause ([docs](https://neo4j.com/docs/cypher-manual/current/clauses/where/)) to the Cypher queries.\n",
    "\n",
    "For instance, maybe we only want to load nodes that are either on the left of the square or on the bottom or both (meaning, not `b`, which is the top right corner)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>n.left</th>\n",
       "      <th>n.top</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>a</th>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>c</th>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>d</th>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      n.left  n.top\n",
       "name               \n",
       "a       True   True\n",
       "c      False  False\n",
       "d       True  False"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "raw_subgraph_nodes = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (n) \n",
    "    WHERE n.left OR NOT n.top\n",
    "    RETURN n.name AS name, n.left, n.top\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "\n",
    "subgraph_nodes = raw_subgraph_nodes.set_index(\"name\")\n",
    "subgraph_nodes"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We've got a set of nodes, and we now need the edges that connect these nodes, and only these nodes. We should not have any edges that involve nodes we didn't select. For our example, that means we need to find the 3 edges between the `a`, `c` and `d` nodes, and avoid the `a`-`b` and `b`-`c` edges.\n",
    "\n",
    "One way to do this is to start with the query for all edges and add a `WHERE` clause that restricts both endpoints to the nodes of interest. That clause can be written in two ways:\n",
    "\n",
    "- pass the identifiers for the selected nodes as parameters into the queries and perform a match with `IN` against the identifiers\n",
    "- reproduce the same filtering on the source and target nodes of each edge\n",
    "\n",
    "The first option can look something like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>d</td>\n",
       "      <td>a</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>a</td>\n",
       "      <td>c</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>c</td>\n",
       "      <td>d</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  source target\n",
       "0      d      a\n",
       "1      a      c\n",
       "2      c      d"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subgraph_edges = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (s) -[r]-> (t)\n",
    "    WHERE s.name IN $node_names AND t.name IN $node_names\n",
    "    RETURN s.name AS source, t.name AS target\n",
    "    \"\"\",\n",
    "    {\"node_names\": list(subgraph_nodes.index)},\n",
    ").to_data_frame()\n",
    "\n",
    "subgraph_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 3, Edges: 3\n",
      "\n",
      " Node types:\n",
      "  default: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [3]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "subgraph = StellarGraph(subgraph_nodes, subgraph_edges)\n",
    "print(subgraph.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The second option can look something like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>d</td>\n",
       "      <td>a</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>a</td>\n",
       "      <td>c</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>c</td>\n",
       "      <td>d</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  source target\n",
       "0      d      a\n",
       "1      a      c\n",
       "2      c      d"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "subgraph_edges_refilter = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (s) -[r]-> (t)\n",
    "    WHERE (s.left OR NOT s.top) AND (t.left OR NOT t.top)\n",
    "    RETURN s.name AS source, t.name AS target\n",
    "    \"\"\"\n",
    ").to_data_frame()\n",
    "\n",
    "subgraph_edges_refilter"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 3, Edges: 3\n",
      "\n",
      " Node types:\n",
      "  default: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [3]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "subgraph_refilter = StellarGraph(subgraph_nodes, subgraph_edges_refilter)\n",
    "print(subgraph_refilter.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Similar filtering can be applied to edges, such as including only edges of specific types, or using more complicated criteria. This can be combined with any node filtering by extending the `WHERE` clause in the edge query to check both the source and target nodes and whatever criteria have been chosen for the edges."
   ]
  },
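  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For small graphs, or while exploring, the same induced-subgraph and edge-type filtering can also be done in Pandas after loading all the edges. This transfers every edge out of Neo4j, which is exactly what the Cypher filters above avoid, so the following is only a sketch of the semantics on a hand-written copy of the square graph's edges. The choice of `horizontal` and `vertical` as the wanted edge types is arbitrary, for illustration.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hand-written copy of the square graph's edges; normally these would come from Cypher\n",
    "all_edges = pd.DataFrame(\n",
    "    {\n",
    "        \"source\": [\"d\", \"a\", \"b\", \"a\", \"c\"],\n",
    "        \"target\": [\"a\", \"b\", \"c\", \"c\", \"d\"],\n",
    "        \"label\": [\"vertical\", \"horizontal\", \"vertical\", \"diagonal\", \"horizontal\"],\n",
    "    }\n",
    ")\n",
    "selected_nodes = [\"a\", \"c\", \"d\"]  # the nodes kept by the node filter\n",
    "\n",
    "# Keep an edge only if both endpoints were selected (the induced subgraph)...\n",
    "both_ends = all_edges[\"source\"].isin(selected_nodes) & all_edges[\"target\"].isin(selected_nodes)\n",
    "# ...and, optionally, only if it has one of the wanted edge types\n",
    "wanted_type = all_edges[\"label\"].isin([\"horizontal\", \"vertical\"])\n",
    "\n",
    "induced_edges = all_edges[both_ends]\n",
    "typed_edges = all_edges[both_ends & wanted_type]\n",
    "print(len(induced_edges), len(typed_edges))  # 3 2\n",
    "```"
   ]
  },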
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### k-Hop subgraphs\n",
    "\n",
    "Another sort of subgraph in which one might be interested is a \"k-hop\" subgraph of a set of start nodes. This consists of all nodes whose shortest path (counted in edges) to some start node has length at most `k`. For example, the 1-hop subgraph around `b` in the square consists of nodes `a`, `b` and `c`, because the shortest path from `b` to `d` has two edges.\n",
    "\n",
    "Many graph machine learning algorithms use only a small neighbourhood of a node to make its predictions, commonly its 1-, 2- or 3-hop subgraph. If we're only interested in feeding small groups of nodes into a model, we can work with just the neighbourhoods of those nodes and avoid loading the rest of the potentially-large graph. This might apply in cases like:\n",
    "\n",
    "- only a small number of nodes have ground-truth labels for training a model\n",
    "- a trained model is being used to predict on only a small group of nodes of interest\n",
    "\n",
    "In many cases, the nodes in the subgraph can be calculated with a Cypher query using [a variable length relationship constraint](https://neo4j.com/docs/cypher-manual/current/clauses/match/#varlength-rels). For instance, to compute the 1-hop subgraph around the `b` node, we might do something like the following cell. Some notes about it:\n",
    "\n",
    "- the `*0..1` means a path of 0 to 1 edges; [the 0 is important](https://neo4j.com/docs/cypher-manual/current/clauses/match/#zero-length-paths) to make sure the `b` node itself is included in the final subgraph. For a 2-hop subgraph, the pattern would be `(start) -[*0..2]- (n)`\n",
    "- it takes a list of start nodes, so that it easily supports multiple start nodes, which is the more common case"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>n.top</th>\n",
       "      <th>n.left</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>name</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>b</th>\n",
       "      <td>True</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>a</th>\n",
       "      <td>True</td>\n",
       "      <td>True</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>c</th>\n",
       "      <td>False</td>\n",
       "      <td>False</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "      n.top  n.left\n",
       "name               \n",
       "b      True   False\n",
       "a      True    True\n",
       "c     False   False"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "start_nodes = [\"b\"]\n",
    "\n",
    "raw_hop_nodes = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (start) -[*0..1]- (n)\n",
    "    WHERE start.name IN $start_nodes\n",
    "    WITH DISTINCT n\n",
    "    RETURN n.name AS name, n.top, n.left\n",
    "    \"\"\",\n",
    "    {\"start_nodes\": start_nodes},\n",
    ").to_data_frame()\n",
    "\n",
    "hop_nodes = raw_hop_nodes.set_index(\"name\")\n",
    "hop_nodes"
   ]
  },
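  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the semantics of the `*0..k` pattern concrete, here is a minimal pure-Python sketch that computes the same node set with a breadth-first search over a hand-written copy of the square graph's edges. The function `k_hop_nodes` is hypothetical, for illustration only. Like the undirected `-[*0..k]-` pattern, it ignores edge direction, and a distance of 0 keeps the start nodes themselves.\n",
    "\n",
    "```python\n",
    "from collections import deque\n",
    "\n",
    "# Hand-written copy of the square graph's edges (direction is ignored below)\n",
    "square_edges = [(\"d\", \"a\"), (\"a\", \"b\"), (\"b\", \"c\"), (\"a\", \"c\"), (\"c\", \"d\")]\n",
    "\n",
    "\n",
    "def k_hop_nodes(edges, start_nodes, k):\n",
    "    # Hypothetical helper: build an undirected adjacency map\n",
    "    neighbours = {}\n",
    "    for s, t in edges:\n",
    "        neighbours.setdefault(s, set()).add(t)\n",
    "        neighbours.setdefault(t, set()).add(s)\n",
    "    # Breadth-first search from all start nodes, tracking the hop distance of each node reached\n",
    "    distance = {node: 0 for node in start_nodes}\n",
    "    queue = deque(start_nodes)\n",
    "    while queue:\n",
    "        node = queue.popleft()\n",
    "        if distance[node] == k:\n",
    "            continue  # do not expand beyond k hops\n",
    "        for other in neighbours.get(node, ()):\n",
    "            if other not in distance:\n",
    "                distance[other] = distance[node] + 1\n",
    "                queue.append(other)\n",
    "    return set(distance)\n",
    "\n",
    "\n",
    "print(sorted(k_hop_nodes(square_edges, [\"b\"], 1)))  # ['a', 'b', 'c']\n",
    "print(sorted(k_hop_nodes(square_edges, [\"b\"], 2)))  # ['a', 'b', 'c', 'd']\n",
    "```"
   ]
  },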
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once we've got the nodes, we can do the same process as in the previous section to get the edges between the nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>source</th>\n",
       "      <th>target</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>a</td>\n",
       "      <td>b</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>b</td>\n",
       "      <td>c</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>a</td>\n",
       "      <td>c</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  source target\n",
       "0      a      b\n",
       "1      b      c\n",
       "2      a      c"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "hop_edges = neo4j_graph.run(\n",
    "    \"\"\"\n",
    "    MATCH (s) -[r]-> (t)\n",
    "    WHERE s.name IN $node_names AND t.name IN $node_names\n",
    "    RETURN s.name AS source, t.name AS target\n",
    "    \"\"\",\n",
    "    {\"node_names\": list(hop_nodes.index)},\n",
    ").to_data_frame()\n",
    "\n",
    "hop_edges"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "StellarGraph: Undirected multigraph\n",
      " Nodes: 3, Edges: 3\n",
      "\n",
      " Node types:\n",
      "  default: [3]\n",
      "    Features: float32 vector, length 2\n",
      "    Edge types: default-default->default\n",
      "\n",
      " Edge types:\n",
      "    default-default->default: [3]\n",
      "        Weights: all 1 (default)\n"
     ]
    }
   ],
   "source": [
    "hop_subgraph = StellarGraph(hop_nodes, hop_edges)\n",
    "print(hop_subgraph.info())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One can expand the query to do more complicated computations, such as filtering which type of edges are included in the paths (like `[:horizontal*0..1]` to only follow horizontal edges), or which nodes are considered with `WHERE` clauses as in the previous section.\n",
    "\n",
    "The `apoc.path.subgraphNodes` procedure ([docs](https://neo4j.com/docs/labs/apoc/current/graph-querying/expand-subgraph-nodes/)) from [the APOC library](https://neo4j.com/docs/labs/apoc/4.0/introduction/) offers more control too."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Saving predictions into Neo4j\n",
    "\n",
    "Most graph machine learning tasks will end up with some sort of predictions about some set of nodes or links in the graph. For example, [a node classification task](../node-classification/gcn-node-classification.ipynb) might produce either a score for each class for every node, or just the single most likely class for each node. These are usually in the following formats:\n",
    "\n",
    "- scores: a multidimensional [NumPy](https://numpy.org) array. In the node classification example linked above, it's an array of floats of shape `(1, 2708, 7)`, where each element along the axis of size 2708 represents a node, and the 7 numbers for that element are the scores for each of the 7 classes for that node.\n",
    "- classes: a one-dimensional NumPy array. In the node classification example linked above, it's an array of strings of length 2708, where each element represents the predicted class for a node.\n",
    "\n",
    "For our graph, let's suppose we have finished predicting the class of a node, with three classes `X`, `Y` and `Z`, and now want to save them back into the Neo4j database to use for visualisation and downstream tasks. For this hypothetical example, we were only interested in predictions for nodes `a` and `b`.\n",
    "\n",
    "The result of the task and all post-processing might be something like:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "predicted_nodes = [\"a\", \"b\"]\n",
    "predicted_scores = np.array([[[0.1, 0.8, 0.1], [0.4, 0.35, 0.25]]])  # rows are nodes a, b\n",
    "predicted_class = np.array([\"Y\", \"X\"])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We want to update the Neo4j database to store the scores in a `predicted_class_scores` property and the class itself in a `predicted_class` property on each node that has predictions. This can be achieved with a parameterised Cypher query using `UNWIND` and `SET`. For this, we need the data as a sequence of records, one for each node."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'name': 'a', 'scores': [0.1, 0.8, 0.1], 'class': 'Y'},\n",
       " {'name': 'b', 'scores': [0.4, 0.35, 0.25], 'class': 'X'}]"
      ]
     },
     "execution_count": 34,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "predictions = [\n",
    "    {\"name\": name, \"scores\": list(scores), \"class\": class_}\n",
    "    for name, scores, class_ in zip(predicted_nodes, predicted_scores[0], predicted_class)\n",
    "]\n",
    "predictions"
   ]
  },
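  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Model outputs are often `float32` arrays (for example, from TensorFlow/Keras). Unlike `np.float64`, `np.float32` is not a subclass of Python's `float`, so a database driver may not know how to send such values. A hedged variant of the cell above converts everything to plain Python values with NumPy's `tolist()` before building the records. The `float32` dtype here is an assumption for illustration, and the conversion exposes float32 rounding (for example, `0.1` becomes `0.10000000149011612`).\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "predicted_nodes = [\"a\", \"b\"]\n",
    "# Assumed float32 scores, shape (1, number of nodes, number of classes)\n",
    "predicted_scores = np.array([[[0.1, 0.8, 0.1], [0.4, 0.35, 0.25]]], dtype=np.float32)\n",
    "predicted_class = np.array([\"Y\", \"X\"])\n",
    "\n",
    "# tolist() turns the NumPy scalars into plain Python floats and strings\n",
    "predictions = [\n",
    "    {\"name\": name, \"scores\": scores, \"class\": class_}\n",
    "    for name, scores, class_ in zip(\n",
    "        predicted_nodes, predicted_scores[0].tolist(), predicted_class.tolist()\n",
    "    )\n",
    "]\n",
    "print(type(predictions[0][\"scores\"][0]).__name__)  # float\n",
    "```"
   ]
  },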
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can execute the query. The `UNWIND` means that `prediction` holds each of the dictionaries in turn, so we can find the relevant node and update its properties as desired."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [],
   "source": [
    "neo4j_graph.evaluate(\n",
    "    \"\"\"\n",
    "    UNWIND $predictions AS prediction\n",
    "    MATCH (n { name: prediction.name })\n",
    "    SET n.predicted_class_scores = prediction.scores\n",
    "    SET n.predicted_class = prediction.class\n",
    "    \"\"\",\n",
    "    {\"predictions\": predictions},\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To verify that this behaved as desired, let's read back all the nodes, to see that `a` and `b` were updated with the right information."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>n.name</th>\n",
       "      <th>n.predicted_class_scores</th>\n",
       "      <th>n.predicted_class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>a</td>\n",
       "      <td>[0.1, 0.8, 0.1]</td>\n",
       "      <td>Y</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>b</td>\n",
       "      <td>[0.4, 0.35, 0.25]</td>\n",
       "      <td>X</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>c</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>d</td>\n",
       "      <td>None</td>\n",
       "      <td>None</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "  n.name n.predicted_class_scores n.predicted_class\n",
       "0      a          [0.1, 0.8, 0.1]                 Y\n",
       "1      b        [0.4, 0.35, 0.25]                 X\n",
       "2      c                     None              None\n",
       "3      d                     None              None"
      ]
     },
     "execution_count": 36,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "verification_data = neo4j_graph.run(\n",
    "    \"MATCH (n) RETURN n.name, n.predicted_class_scores, n.predicted_class\"\n",
    ").to_data_frame()\n",
    "\n",
    "verification_data.sort_values(\"n.name\")  # sort for ease of reference"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "This notebook demonstrated how to read data from Neo4j into a `StellarGraph` graph object for several types of graphs:\n",
    "\n",
    "- with or without node features\n",
    "- with or without edge weights\n",
    "- directed or not\n",
    "- homogeneous or heterogeneous\n",
    "\n",
    "We used the `py2neo` library to run Cypher queries that create Pandas DataFrames, which we then loaded into `StellarGraph` objects. The process for loading from Pandas DataFrames is explored in more detail in the [loading via Pandas](loading-pandas.ipynb) demonstration, which has more discussion and explains every option for finer control.\n",
    "\n",
    "This notebook also demonstrated saving the results of a graph machine learning algorithm back into Neo4j to use for visualisation and other tasks.\n",
    "\n",
    "Revisit this document whenever you need a reminder.\n",
    "\n",
    "Once you've loaded your data, you can start doing machine learning: a good place to start is the [demo of the GCN algorithm on the Cora dataset for node classification](../node-classification/gcn-node-classification.ipynb). Additionally, StellarGraph includes [many other demos of other algorithms, solving other tasks](../README.md).\n",
    "\n",
    "We also have experimental support for [running some algorithms directly using Neo4j](../connector/neo4j/README.md).\n",
    "\n",
    "(We're still exploring the best ways to have StellarGraph work with Neo4j, so please [let us know](https://github.com/stellargraph/stellargraph#getting-help) your experience of using StellarGraph with Neo4j, both positive and negative.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "metadata": {},
   "outputs": [],
   "source": [
    "# clean everything up, so that we're not leaving the square graph in the Neo4j instance\n",
    "neo4j_graph.delete_all()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
